{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "65811579",
   "metadata": {},
   "source": [
    "# L5 Closed-loop Gym-compatible Environment\n",
    "\n",
    "This notebook demonstrates some of the aspects of our gym-compatible closed-loop environment.\n",
    "\n",
    "You will understand the inner workings of our L5Kit environment and an RL policy can be used to rollout the environment. \n",
    "\n",
    "Note: The training of different RL policies in our environment will be shown in a separate notebook.\n",
    "\n",
    "![drivergym](../../docs/images/rl/drivergym.png)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "69ad643d",
   "metadata": {},
   "outputs": [],
   "source": [
    "#@title Download L5 Sample Dataset and install L5Kit\n",
    "import os\n",
    "RunningInCOLAB = 'google.colab' in str(get_ipython())\n",
    "if RunningInCOLAB:\n",
    "    !wget https://raw.githubusercontent.com/lyft/l5kit/master/examples/setup_notebook_colab.sh -q\n",
    "    !sh ./setup_notebook_colab.sh\n",
    "    os.environ[\"L5KIT_DATA_FOLDER\"] = open(\"./dataset_dir.txt\", \"r\").read().strip()\n",
    "else:\n",
    "    os.environ[\"L5KIT_DATA_FOLDER\"] = \"/tmp/level5_data\"\n",
    "    print(\"Not running in Google Colab.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7eebda10",
   "metadata": {},
   "outputs": [],
   "source": [
    "import gym\n",
    "import matplotlib.pyplot as plt\n",
    "import torch\n",
    "import numpy as np\n",
    "\n",
    "import l5kit.environment\n",
    "from l5kit.configs import load_config_data\n",
    "from l5kit.environment.envs.l5_env import EpisodeOutputGym, SimulationConfigGym\n",
    "from l5kit.environment.gym_metric_set import L2DisplacementYawMetricSet\n",
    "from l5kit.visualization.visualizer.zarr_utils import episode_out_to_visualizer_scene_gym_cle\n",
    "from l5kit.visualization.visualizer.visualizer import visualize\n",
    "\n",
    "from bokeh.io import output_notebook, show\n",
    "from prettytable import PrettyTable\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "551fc59b",
   "metadata": {},
   "source": [
    "### First, let's configure where our data lives!\n",
    "The data is expected to live in a folder that can be configured using the `L5KIT_DATA_FOLDER` env variable. Your data folder is expected to contain subfolders for the aerial and semantic maps as well as the scenes (`.zarr` files). \n",
    "In this example, the env variable is set to the local data folder. You should make sure the path points to the correct location for you.\n",
    "\n",
    "We built our code to work with a human-readable `yaml` config. This config file holds much useful information, however, we will only focus on a few functionalities concerning the creation of our gym environment here"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "864884da",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dataset is assumed to be on the folder specified\n",
    "# in the L5KIT_DATA_FOLDER environment variable\n",
    "\n",
    "# get environment config\n",
    "env_config_path = '../gym_config.yaml'\n",
    "cfg = load_config_data(env_config_path)\n",
    "print(cfg)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f3fba368",
   "metadata": {},
   "source": [
    "### We can look into our current configuration for interesting fields\n",
    "\n",
    "\\- when loaded in python, the `yaml`file is converted into a python `dict`. \n",
    "\n",
    "`raster_params` contains all the information related to the transformation of the 3D world onto an image plane:\n",
    "  - `raster_size`: the image plane size\n",
    "  - `pixel_size`: how many meters correspond to a pixel\n",
    "  - `ego_center`: our raster is centered around an agent, we can move the agent in the image plane with this param\n",
    "  - `map_type`: the rasterizer to be employed. We currently support a satellite-based and a semantic-based one. We will look at the differences further down in this script\n",
    "  \n",
    "The `raster_params` are used to determine the observation provided by our gym environment to the RL policy."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5d631092",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(f'current raster_param:\\n')\n",
    "for k,v in cfg[\"raster_params\"].items():\n",
    "    print(f\"{k}:{v}\")"
   ]
  },
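  {
   "cell_type": "markdown",
   "id": "c4a1e7d2",
   "metadata": {},
   "source": [
    "To make the raster parameters more concrete, the sketch below works out the area covered by a raster and the pixel at which the ego is drawn. The values passed in are illustrative assumptions rather than the ones in `gym_config.yaml`, and `raster_footprint` is a hypothetical helper that is not part of L5Kit. It also assumes that `ego_center` is expressed as a fraction of the image size."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d9b03f56",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "def raster_footprint(raster_size, pixel_size, ego_center):\n",
    "    \"\"\"Return the raster extent in meters and the ego position in pixels.\"\"\"\n",
    "    raster_size = np.asarray(raster_size, dtype=float)\n",
    "    # Extent covered by the raster: number of pixels times meters per pixel\n",
    "    extent_m = raster_size * np.asarray(pixel_size, dtype=float)\n",
    "    # ego_center is assumed to be a fraction of the image size\n",
    "    ego_pixel = raster_size * np.asarray(ego_center, dtype=float)\n",
    "    return extent_m, ego_pixel\n",
    "\n",
    "\n",
    "# Illustrative values only, not read from gym_config.yaml\n",
    "extent_m, ego_pixel = raster_footprint([112, 112], [0.5, 0.5], [0.25, 0.5])\n",
    "print(f\"covered area: {extent_m[0]} m x {extent_m[1]} m\")\n",
    "print(f\"ego drawn at pixel: {ego_pixel}\")"
   ]
  },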
  {
   "cell_type": "markdown",
   "id": "61f3c2f3",
   "metadata": {},
   "source": [
    "## Create L5 Closed-loop Environment\n",
    "\n",
    "We will now create an instance of the L5Kit gym-compatible environment. As you can see, we need to provide the path to the configuration file of the environment. \n",
    "\n",
    "1. The `rescale_action` flag rescales the policy action based on dataset statistics. This argument helps for faster convergence during policy training. \n",
    "2. The `return_info` flag informs the environment to return the episode output everytime an episode is rolled out. \n",
    "\n",
    "Note: The environment has already been registered with gym during initialization of L5Kit.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fd8efa0d",
   "metadata": {},
   "outputs": [],
   "source": [
    "env = gym.make(\"L5-CLE-v0\", env_config_path=env_config_path, rescale_action=False, return_info=True)\n"
   ]
  },
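  {
   "cell_type": "markdown",
   "id": "e2f6a81c",
   "metadata": {},
   "source": [
    "The idea behind `rescale_action` can be sketched as follows: a policy typically outputs values in a normalized range, and these are mapped to metric steps using a mean and a standard deviation computed from the dataset. The statistics below are made-up placeholders and `rescale` is a hypothetical helper; the exact formula and values used inside L5Kit may differ."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f70c5b39",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Placeholder statistics for (x, y, yaw) steps; NOT the real dataset values\n",
    "ACTION_MEAN = np.array([1.0, 0.0, 0.0])\n",
    "ACTION_STD = np.array([0.5, 0.02, 0.01])\n",
    "\n",
    "\n",
    "def rescale(normalized_action, mean=ACTION_MEAN, std=ACTION_STD):\n",
    "    \"\"\"Map a normalized policy output to a metric (x, y, yaw) step.\"\"\"\n",
    "    return mean + std * np.asarray(normalized_action, dtype=float)\n",
    "\n",
    "\n",
    "# A zero output maps to the mean step; other outputs scale with the std\n",
    "print(rescale([0.0, 0.0, 0.0]))\n",
    "print(rescale([1.0, -1.0, 0.0]))"
   ]
  },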
  {
   "cell_type": "markdown",
   "id": "b5b49668",
   "metadata": {},
   "source": [
    "## Visualize an observation from the environment\n",
    "\n",
    "Let us visualize the observation from the environment. We will reset the environment and visualize an observation which is provided by the environment."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "39fbbb48",
   "metadata": {},
   "outputs": [],
   "source": [
    "obs = env.reset()\n",
    "im = obs[\"image\"].transpose(1, 2, 0)\n",
    "im = env.dataset.rasterizer.to_rgb(im)\n",
    "\n",
    "plt.imshow(im)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d4bb6922",
   "metadata": {},
   "source": [
    "## Rollout an episode from the environment\n",
    "\n",
    "### The rollout of an episode in our environment takes place in three steps:\n",
    "\n",
    "### Gym Environment Update:\n",
    "1. Reward Calculation (CLE): Given an action from the policy, the environment will calculate the reward received as a consequence of the action.\n",
    "2. Internal State Update: Since we are rolling out the environment in closed-loop, the internal state of the ego is updated based on the action.\n",
    "3. Raster rendering: A new raster image is rendered based on the predicted ego position and returned as the observation of next time-step.\n",
    "\n",
    "### Policy Forward Pass\n",
    "The policy takes as input the observation provided by the environment and outputs the action via a forward pass.\n",
    "\n",
    "### Inter-process communication\n",
    "Usually, we deploy different subprocesses to rollout parallel environments to speed up rollout time during training. Each subprocess rolls out one environemnt. In such scenarios, there is an additional component called inter-process communication: The subprocess outputs (observations) are aggregated and passed to the main process and vice versa (for the actions)\n",
    "\n",
    "![rollout](../../docs/images/rl/policy_rollout.png)"
   ]
  },
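  {
   "cell_type": "markdown",
   "id": "a8d41e60",
   "metadata": {},
   "source": [
    "The environment update and the policy forward pass described above can be sketched with a toy example that has no dependencies. `ToyEnv`, `constant_policy` and `run_toy_episode` are hypothetical stand-ins for illustration only: the ego moves along a line, the reward is the negative distance to a logged position, and the observation is simply the ego position rather than a raster."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b5e92c7f",
   "metadata": {},
   "outputs": [],
   "source": [
    "class ToyEnv:\n",
    "    \"\"\"1D stand-in for a closed-loop environment (illustration only).\"\"\"\n",
    "\n",
    "    def __init__(self, logged_positions):\n",
    "        self.logged = list(logged_positions)\n",
    "        self.t = 0\n",
    "        self.ego_x = self.logged[0]\n",
    "\n",
    "    def reset(self):\n",
    "        self.t = 0\n",
    "        self.ego_x = self.logged[0]\n",
    "        return self.ego_x\n",
    "\n",
    "    def step(self, action):\n",
    "        self.t += 1\n",
    "        predicted_x = self.ego_x + action\n",
    "        # 1. Reward calculation: penalize the distance to the logged position\n",
    "        reward = -abs(predicted_x - self.logged[self.t])\n",
    "        # 2. Internal state update: the ego follows the action (closed loop)\n",
    "        self.ego_x = predicted_x\n",
    "        # 3. Rendering: the next observation reflects the updated state\n",
    "        obs = self.ego_x\n",
    "        done = self.t == len(self.logged) - 1\n",
    "        return obs, reward, done, {}\n",
    "\n",
    "\n",
    "def constant_policy(obs):\n",
    "    \"\"\"Policy forward pass: ignore the observation and advance by 1.0.\"\"\"\n",
    "    return 1.0\n",
    "\n",
    "\n",
    "def run_toy_episode(toy_env, policy):\n",
    "    \"\"\"Alternate policy and environment steps until the episode ends.\"\"\"\n",
    "    obs, total_reward, done = toy_env.reset(), 0.0, False\n",
    "    while not done:\n",
    "        obs, reward, done, _ = toy_env.step(policy(obs))\n",
    "        total_reward += reward\n",
    "    return obs, total_reward\n",
    "\n",
    "\n",
    "final_x, total_reward = run_toy_episode(ToyEnv([0.0, 1.0, 2.5, 4.0]), constant_policy)\n",
    "print(final_x, total_reward)"
   ]
  },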
  {
   "cell_type": "markdown",
   "id": "1a2e88cb",
   "metadata": {},
   "source": [
    "### Dummy Policy\n",
    "\n",
    "For this notebook, we will not train the policy but use a dummy policy. Our dummy policy that will move the ego by 10 m/s along the direction of orientation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bb4b7249",
   "metadata": {},
   "outputs": [],
   "source": [
    "class DummyPolicy(torch.nn.Module):\n",
    "    \"\"\"A policy that advances the ego by constant speed along x-direction.\n",
    "\n",
    "    :param advance_x: the distance to advance per time-step\n",
    "    \"\"\"\n",
    "    def __init__(self, advance_x: float = 0.0):\n",
    "        super(DummyPolicy, self).__init__()\n",
    "        self.advance_x = advance_x\n",
    "\n",
    "    def forward(self, x):\n",
    "        positions_and_yaws = torch.zeros(3,)\n",
    "        positions_and_yaws[..., 0] = self.advance_x\n",
    "\n",
    "        return positions_and_yaws.cpu().numpy()\n",
    "\n",
    "# We multiple the desired speed by the step-time (inverse of frequency) of data collection\n",
    "desired_speed = 10.0\n",
    "dummy_policy = DummyPolicy(cfg[\"model_params\"][\"step_time\"] * desired_speed)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f172f6b5",
   "metadata": {},
   "source": [
    "Let us now rollout the environment using the dummy policy. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0b1a57a8",
   "metadata": {},
   "outputs": [],
   "source": [
    "def rollout_episode(env, idx = 0):\n",
    "    \"\"\"Rollout a particular scene index and return the simulation output.\n",
    "\n",
    "    :param env: the gym environment\n",
    "    :param idx: the scene index to be rolled out\n",
    "    :return: the episode output of the rolled out scene\n",
    "    \"\"\"\n",
    "\n",
    "    # Set the reset_scene_id to 'idx'\n",
    "    env.reset_scene_id = idx\n",
    "    \n",
    "    # Rollout step-by-step\n",
    "    obs = env.reset()\n",
    "    while True:\n",
    "        action = dummy_policy(obs)\n",
    "        obs, _, done, info = env.step(action)\n",
    "        if done:\n",
    "            break\n",
    "    \n",
    "    # The episode outputs are present in the key \"sim_outs\"\n",
    "    sim_out = info[\"sim_outs\"][0]\n",
    "    return sim_out\n",
    "\n",
    "# Rollout one episode\n",
    "sim_out = rollout_episode(env)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "808aaded",
   "metadata": {},
   "source": [
    "## Visualize the episode from the environment\n",
    "\n",
    "We can easily visualize the outputs obtained by rolling out episodes in the L5Kit using the Bokeh visualizer."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "41b124a4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# might change with different rasterizer\n",
    "map_API = env.dataset.rasterizer.sem_rast.mapAPI\n",
    "\n",
    "def visualize_outputs(sim_outs, map_API):\n",
    "    for sim_out in sim_outs: # for each scene\n",
    "        vis_in = episode_out_to_visualizer_scene_gym_cle(sim_out, map_API)\n",
    "        show(visualize(sim_out.scene_id, vis_in))\n",
    "\n",
    "output_notebook()\n",
    "visualize_outputs([sim_out], map_API)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "39f51d86",
   "metadata": {},
   "source": [
    "## Calculate the performance metrics from the episode outputs\n",
    "\n",
    "We can also calculate the various quantitative metrics on the rolled out episode output. "
   ]
  },
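  {
   "cell_type": "markdown",
   "id": "c1f38a04",
   "metadata": {},
   "source": [
    "As a reference for the two numbers reported in the next cell: ADE (average displacement error) is the mean L2 distance between the simulated and the logged ego positions over the episode, and FDE (final displacement error) is that distance at the last time-step. The sketch below computes both with NumPy on made-up trajectories; `ade_fde` is a hypothetical helper, not an L5Kit function, and unlike the next cell it averages over every frame it is given."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d6a07b95",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "def ade_fde(simulated_xy, logged_xy):\n",
    "    \"\"\"Return (ADE, FDE) for two trajectories of shape (T, 2).\"\"\"\n",
    "    simulated_xy = np.asarray(simulated_xy, dtype=float)\n",
    "    logged_xy = np.asarray(logged_xy, dtype=float)\n",
    "    # Per-frame L2 distance between simulated and logged positions\n",
    "    l2 = np.linalg.norm(simulated_xy - logged_xy, axis=1)\n",
    "    return l2.mean(), l2[-1]\n",
    "\n",
    "\n",
    "# Made-up trajectories: the simulated ego drifts 1 m further away per frame\n",
    "simulated = [[0.0, 0.0], [1.0, 0.0], [2.0, 0.0]]\n",
    "logged = [[0.0, 0.0], [1.0, 1.0], [2.0, 2.0]]\n",
    "ade, fde = ade_fde(simulated, logged)\n",
    "print(f\"ADE = {ade:.2f} m, FDE = {fde:.2f} m\")"
   ]
  },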
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "98c31015",
   "metadata": {},
   "outputs": [],
   "source": [
    "def quantify_outputs(sim_outs, metric_set=None):\n",
    "    metric_set = metric_set if metric_set is not None else L2DisplacementYawMetricSet()\n",
    "\n",
    "    metric_set.evaluate(sim_outs)\n",
    "    scene_results = metric_set.evaluator.scene_metric_results\n",
    "    fields = [\"scene_id\", \"FDE\", \"ADE\"]\n",
    "    table = PrettyTable(field_names=fields)\n",
    "    tot_fde = 0.0\n",
    "    tot_ade = 0.0\n",
    "    for scene_id in scene_results:\n",
    "        scene_metrics = scene_results[scene_id]\n",
    "        ade_error = scene_metrics[\"displacement_error_l2\"][1:].mean()\n",
    "        fde_error = scene_metrics['displacement_error_l2'][-1]\n",
    "        table.add_row([scene_id, round(fde_error.item(), 4), round(ade_error.item(), 4)])\n",
    "        tot_fde += fde_error.item()\n",
    "        tot_ade += ade_error.item()\n",
    "\n",
    "    ave_fde = tot_fde / len(scene_results)\n",
    "    ave_ade = tot_ade / len(scene_results)\n",
    "    table.add_row([\"Overall\", round(ave_fde, 4), round(ave_ade, 4)])\n",
    "    print(table)\n",
    "\n",
    "\n",
    "quantify_outputs([sim_out])"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
